gpu support #256

Open
Qfl3x wants to merge 10 commits into main from ac/gpu_support

Conversation

@Qfl3x (Collaborator) commented Apr 2, 2026

No description provided.


gemini-code-assist (bot) left a comment

Code Review

This pull request introduces GPU support and refactors data handling to separate predictors and forcings throughout the training pipeline. Key changes include adding CUDA dependencies, updating configuration objects with device selectors, and modifying data loaders, splitters, and model forward passes to accommodate a new nested tuple input structure. Feedback highlights a mathematical error in the R-squared calculation, potential shape mismatches and incorrect NaN masking in the epoch loop, and several instances of dead code or typos. Additionally, a logic error was identified in a warning check within the data preparation module.

```diff
 function loss_fn(ŷ, y, y_nan, ::Val{:r2})
-    r = cor(ŷ[y_nan], y[y_nan])
-    return r * r
+    return 1 - sum((y[y_nan] .- ŷ[y_nan]).^2) / sum((y[y_nan] .- mean(ŷ[y_nan])).^2)
```

high

The R-squared calculation is incorrect. The denominator should use the mean of the observed values (y), not the predicted values (ŷ). The standard definition of R² is $1 - SS_{res}/SS_{tot}$, where $SS_{tot}$ is calculated relative to the mean of the observations.

```julia
    return 1 - sum((y[y_nan] .- ŷ[y_nan]).^2) / sum((y[y_nan] .- mean(y[y_nan])).^2)
```
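For reference, a standalone sketch contrasting the two definitions (not part of the PR; the data here is made up). With the typo (`mean(ŷ)` in the denominator) the coefficient of determination would be silently wrong:

```julia
using Statistics

y = [1.1, 1.9, 3.2, 4.05]   # observations
ŷ = [1.0, 2.0, 3.0, 4.0]    # predictions

# Old behaviour: squared Pearson correlation
r2_cor = cor(ŷ, y)^2

# Coefficient of determination: SS_tot is taken around mean(y), not mean(ŷ)
r2_cod = 1 - sum(abs2, y .- ŷ) / sum(abs2, y .- mean(y))
```

The two only coincide for an unbiased least-squares fit; in general they differ.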

Comment on lines +5 to +8

```julia
is_no_nan = falses(length(first(y))) |> cfg.gdev
for vec in y
    is_no_nan = is_no_nan .|| .!isnan.(vec)
end
```

high

This logic has two significant issues:

  1. Shape Mismatch: falses(length(first(y))) creates a 1D array. If the targets are multi-dimensional (e.g., (time, batch)), the bitwise OR operation .|| will fail. Use size(first(y)) instead of length.
  2. Incorrect Masking: Computing a single global is_no_nan mask by ORing all targets is problematic. If target A has a NaN at an index where target B is valid, the global mask will be true at that index. Consequently, the loss for target A will be computed using the NaN value, resulting in a NaN total loss. Masks should be computed and applied per-target.
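A minimal self-contained sketch of the per-target masking suggested here, assuming `y` (observations) and `ŷ` (predictions) are NamedTuples of arrays, mirroring the nested structure introduced in this PR:

```julia
# Hedged sketch: one NaN mask per target instead of a single global OR-ed mask.
y = (; var1 = [1.0, NaN, 3.0], var2 = [2.0, 3.0, NaN])
ŷ = (; var1 = [1.1, 2.0, 2.9], var2 = [1.9, 3.1, 4.0])

masks = map(v -> .!isnan.(v), y)         # one mask per target, same shape
per_target = map(keys(y)) do k
    m = masks[k]
    sum(abs2, y[k][m] .- ŷ[k][m])        # loss restricted to valid entries of k
end
total_loss = sum(per_target)             # finite, no NaN leakage across targets
```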

```diff
 @warn "Note that you don't have target names."
 end
-return predictors_forcing, targets
+return predictors, forcings, targets
```

medium

The function now returns predictors and forcings separately, but the warning check at line 115 (visible in context) still references predictors_forcing. Since predictors_forcing is initialized as an empty array at line 89 and never populated in the new logic, this warning will be triggered on every call. The check should be updated to verify if both predictors and forcings are empty.
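A hedged sketch of the suggested fix: guard the warning on both of the new collections being empty, instead of the now-unused `predictors_forcing` (the surrounding code is assumed, only the check changes):

```julia
# Sketch: warn only when both new collections are empty.
if isempty(predictors) && isempty(forcings)
    @warn "Note that you don't have target names."
end
return predictors, forcings, targets
```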

Comment on lines +108 to +109

```julia
y_t = y[target]# _get_target_y(y, target)
ŷ_t = ŷ[target]#_get_target_ŷ(ŷ, y_t, target)
```

medium

There is a typo in line 109: ŷ written as y followed by U+0302 (combining circumflex) is used instead of the argument ŷ (precomposed, U+0177) defined at line 105. While Julia normalizes identifiers to NFC, mixing these characters is confusing and can lead to issues in environments with different normalization rules. Additionally, the commented-out code should be removed.

```julia
y_t = y[target]
ŷ_t = ŷ[target]
```
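A small standalone illustration of why the two spellings differ at the string level even though NFC normalization maps one to the other (not part of the PR):

```julia
using Unicode

decomposed  = "y\u0302"   # 'y' followed by U+0302 COMBINING CIRCUMFLEX ACCENT
precomposed = "\u0177"    # U+0177 LATIN SMALL LETTER Y WITH CIRCUMFLEX

decomposed == precomposed                            # false: different codepoints
Unicode.normalize(decomposed, :NFC) == precomposed   # true after NFC normalization
```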

Comment on lines +19 to +20

```julia
ps, st = LuxCore.setup(Random.default_rng(), model) |> cfg.gdev
ps = ps |> cfg.gdev
```

medium

The line ps = ps |> cfg.gdev is redundant. The parameters ps and state st are already moved to the device as part of the piped operation in line 19.

```julia
ps, st = LuxCore.setup(Random.default_rng(), model) |> cfg.gdev
```

Member

indeed, line 20 is not needed here.

@lazarusA mentioned this pull request Apr 3, 2026
```diff
-return (data(predictors_forcing), data(targets))
+dev = cfg.gdev
+targets_nt = NamedTuple([target => dev(Array(data(target))) for target in targets])
+forcings_nt = NamedTuple([forcing => dev(Array(data(forcing))) for forcing in forcings])
```
@lazarusA (Member) Apr 7, 2026

we should do dev/Array at the batch loader level. Up to this point the data could still be lazy.

```diff
-ŷ_train, _ = model(x_train, ps, LuxCore.testmode(st))
-ŷ_val, _ = model(x_val, ps, LuxCore.testmode(st))
+ŷ_train, _ = model((cfg.cdev(x_train), cfg.cdev(forcings_train)), cfg.cdev(ps), LuxCore.testmode(st))
+ŷ_val, _ = model((cfg.cdev(x_val), cfg.cdev(forcings_val)), cfg.cdev(ps), LuxCore.testmode(st))
```
Member


I think we can still evaluate this on the GPU side and just pipe the result into the cfg.cdev function.

Comment on lines +4 to +5

```julia
for (x, y) in cfg.gdev(loader)
    is_no_nan = falses(length(first(y))) |> cfg.gdev
```
Member


cfg.gdev(loader) already moves the data onto the GPU device, so the second line is already an operation on the GPU side; the |> cfg.gdev should not be needed, in principle.

@lazarusA mentioned this pull request Apr 7, 2026

lazarusA commented Apr 7, 2026

comments are being addressed in #257

@Qfl3x (Collaborator, Author) commented Apr 7, 2026

Main things I changed here are:

  1. Masking is now pre-computed; I found it was a bit expensive previously.
  2. Cleaner device switching.
  3. Changed the previous tests to conform to the new interface.

I think the main thing that needs work here and/or in #257 is device switching, namely that the device switch should happen at the batch level. CUDA needs contiguous arrays to work anyway, so views are a no-go. Meaning we have to allocate a new array at every batch anyway, so we may as well do the device switching there.
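The batch-level transfer described above could look roughly like this (a sketch under assumptions: `loader`, `cfg.gdev`, and the training step follow the shapes used elsewhere in this PR):

```julia
# Hedged sketch: materialize each (possibly lazy or view-backed) batch into a
# contiguous Array, then move it to the device once per batch.
for (x, y) in loader
    x_gpu = cfg.gdev(Array(x))   # contiguous host copy, then host -> device
    y_gpu = cfg.gdev(Array(y))
    # ... run the training step on x_gpu / y_gpu ...
end
```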

I worked here before I saw all the comments. So I'll now switch to #257

@lazarusA (Member) commented Apr 7, 2026

We can also continue here, whichever branch is easier to merge. Both are still open. I just didn't want to have merge conflicts, just in case you did local work 😌, as you have done 👍.

@lazarusA changed the title from "Ac/gpu support" to "gpu support" Apr 7, 2026
```julia
@testset "_compute_loss" begin
    # Test data setup
    ŷ = Dict(:var1 => [1.0, 2.0, 3.0], :var2 => [2.0, 3.0, 4.0])
    y(target) = target == :var1 ? [1.1, 1.9, 3.2] : [1.8, 3.1, 3.9]
```
Member

doing y(target) was intended to mirror AxisKeys syntax, although admittedly it would be better to do an independent test for that, and similarly for DD.

Collaborator (Author)

For now I've written the interface for loss functions to accept named tuples. So I may have to add a dispatch for callables.

Collaborator (Author)

These should work again.

@lazarusA (Member) commented Apr 8, 2026

Regarding data:

  • CSV files -> DataFrames (non-AD) -> AxisKeys (AD)
  • zarr / nc -> YAXArray (DD, which now also supports AD); less tested, but examples in main with direct DD types are working.

```diff
@@ -2,6 +2,10 @@ function run_epoch!(loader, model, ps, st, train_state, cfg::TrainConfig)
     loss_fn = build_loss_fn(model, cfg)

     for (x, y) in loader
```
Member


I think it is better to do cfg.gdev(loader); it is just a wrapper at this point, I think. Then, when looped over, the (x, y) will be sent to the GPU. If it works now, we can come back to this later.

Collaborator (Author)


It works to do cfg.gdev(loader) except for dimensional data; that is still lazy and not a simple array like the rest.

@Qfl3x (Collaborator, Author) commented Apr 8, 2026

Most things have been addressed. Docs still have problems.

@lazarusA (Member) commented Apr 8, 2026

I suppose the docs issues are related to the new prepare data step.

@Qfl3x (Collaborator, Author) commented Apr 8, 2026

/gemini review

gemini-code-assist (bot) left a comment

Code Review

This pull request significantly refactors the data handling and training pipeline to better separate predictors and forcings, and to improve GPU compatibility. Key changes include updating data structures (prepare_data, split_data) to explicitly distinguish between predictors and forcings, integrating Lux's gpu_device and cpu_device for managing device transfers during training and saving, and adapting various training functions (run_epoch!, evaluate_epoch, save_final!) to these new data formats and device considerations. The r2 loss calculation has been redefined, and new utility functions like collect_dim_data and valid_mask are introduced. Review comments highlight critical areas for improvement, including addressing type instability in TrainConfig's device fields, ensuring consistent CPU transfer for model states during evaluation, improving the robustness of data access in to_named_tuple, optimizing data collection in collect_dim_data for efficiency, making the R² loss calculation GPU-friendly, and enhancing semantic clarity in loss computation.

Comment on lines +31 to +35

```julia
"Select a gpu_device or default to cpu if none available"
gdev = gpu_device()

"Set the `cpu_device`, useful for sending back to the cpu model parameters"
cdev = cpu_device()
```

high

The fields gdev and cdev are untyped in the TrainConfig struct. In Julia, untyped fields lead to type instability, which can significantly degrade performance because the compiler cannot specialize functions using these fields. Since these are used frequently for device transfers during training, it is highly recommended to provide type annotations, such as Lux.AbstractDevice or using type parameters.
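One way to address this (a sketch, not the PR's actual struct, which has more fields) is to parameterize `TrainConfig` on the device types so the compiler can specialize on them:

```julia
using Lux  # gpu_device / cpu_device come from the MLDataDevices layer

# Hedged sketch: concretely typed device fields via type parameters.
struct TrainConfig{G, C}
    "Select a gpu_device or default to cpu if none available"
    gdev::G
    "Set the `cpu_device`, useful for sending model parameters back to the cpu"
    cdev::C
end

TrainConfig() = TrainConfig(gpu_device(), cpu_device())
```

With this, `cfg.gdev` has a concrete type (e.g. `CUDADevice` or `CPUDevice`) fixed at construction, avoiding dynamic dispatch on every device transfer.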


```diff
 function to_named_tuple(ka, target_names)
-    arrays = [Array(ka(variable = k)) for k in target_names]
+    arrays = [Array(ka[k]) for k in target_names]
```

high

The change from ka(variable = k) to ka[k] will break if ka is a KeyedArray, as KeyedArray indexing typically requires dimension names or positional indices. While this might work if ka is now a NamedTuple due to changes in prepare_data, it makes the function less robust if it's still intended to handle KeyedArray inputs.

Comment on lines +2 to +5

```julia
x_col = Array(x[1])
forcing_nt = NamedTuple([k => Array(v) for (k, v) in pairs(x[2])])
targets_nt = NamedTuple([k => Array(v) for (k, v) in pairs(y[1])])
masks_nt = NamedTuple([k => Array(v) for (k, v) in pairs(y[2])])
```

medium

The construction of forcing_nt, targets_nt, and masks_nt using a loop and the NamedTuple constructor is inefficient. Since the inputs are already NamedTuples (as per the changes in prepare_data), you can use map which is more idiomatic and performant in Julia. Additionally, calling Array(v) on every batch is redundant if the data is already on the CPU.

```julia
    x_col = x[1]
    forcing_nt = map(identity, x[2])
    targets_nt = map(identity, y[1])
    masks_nt = map(identity, y[2])
```

Collaborator (Author)


The reason behind this is that if the inputs are DimArrays they need to be collected explicitly, hence the Array calls.

```diff
 function loss_fn(ŷ, y, y_nan, ::Val{:r2})
-    r = cor(ŷ[y_nan], y[y_nan])
-    return r * r
+    return 1 - sum((y[y_nan] .- ŷ[y_nan]) .^ 2) / sum((y[y_nan] .- mean(y[y_nan])) .^ 2)
```

medium

The implementation of R² using logical indexing (e.g., y[y_nan]) is inefficient on GPUs because it often triggers scalar indexing, which is extremely slow. A more GPU-friendly approach is to use the mask to zero out invalid entries and then perform vectorized reductions. Also, note that this change shifts the definition of R² from squared Pearson correlation to the coefficient of determination, which might be a breaking change for users.

```julia
function loss_fn(ŷ, y, y_nan, ::Val{:r2})
    ss_res = sum(abs2, (y .- ŷ) .* y_nan)
    y_mean = sum(y .* y_nan) / sum(y_nan)
    ss_tot = sum(abs2, (y .- y_mean) .* y_nan)
    return 1 - ss_res / ss_tot
end
```

Comment on lines +110 to +112

```julia
y_nan_t = _get_target_y(y_nan, target)
_apply_loss(ŷ_t, y_t, y_nan_t, loss_spec)
# _apply_loss(ŷ_t, y_t, _get_target_nan(y_nan, target), loss_spec)
```

medium

The use of _get_target_y for a variable named y_nan is semantically confusing. It is better to use _get_target_nan, which was specifically defined for this purpose. Also, please remove the commented-out code to keep the codebase clean.

```julia
y_nan_t = _get_target_nan(y_nan, target)
_apply_loss(ŷ_t, y_t, y_nan_t, loss_spec)
```
